library(topicdict)
library(purrr)
library(quanteda)
library(tibble)
library(ggplot2)
library(dplyr)
library(topicmodels)

# Create model function:
# Build a (not yet trained) topicdict model from documents and keyword lists.
#
# Args:
#   docs:      document file paths, passed straight to topicdict_model().
#   seed_list: list of keyword vectors; element i holds the keywords of seeded
#              topic i. Elements are renamed "1", "2", ... before building the
#              quanteda dictionary.
#   extra_k:   number of additional, unseeded topics.
#   seed:      RNG seed set before building the model (225 = value used
#              throughout this analysis).
#
# Returns: the object returned by topicdict_model().
create_model <- function(docs, seed_list, extra_k, seed = 225) {
  set.seed(seed)
  # seq_along() instead of 1:length(): an empty list must not get names c(1, 0)
  names(seed_list) <- seq_along(seed_list)
  dict <- quanteda::dictionary(seed_list)
  topicdict_model(
    docs,
    dict = dict, extra_k = extra_k,
    remove_numbers = FALSE,
    remove_punct = TRUE,
    remove_symbols = TRUE,
    remove_separators = TRUE
  )
}
# Check dispersion
# Convert the top words of a trained topicdict model into a long data frame.
#
# Args:
#   model: the model passed to topicdict_train(); only `model$seeds` is read,
#          and only when `show` is TRUE.
#   res:   the result of topicdict_train().
#   n:     number of top words per topic.
#   show:  if TRUE, print the top-word table of the seeded topics.
#
# Returns: data frame with columns EstTopic ("EstTopic1", ...) and Word.
#   Everything from the first whitespace onward is removed from each word,
#   which drops annotations such as " [✓]" that top_terms() appends.
tidy_seededlda_out <- function(model, res, n = 15, show = FALSE) {
  post <- topicdict::posterior(res)
  # Compute the top-term matrix once; it is reused for printing below.
  top_mat <- top_terms(post, n = n)
  topwords <- data.frame(top_mat)
  colnames(topwords) <- paste0("EstTopic", seq_len(ncol(topwords)))
  otidy <- topwords %>%
    tidyr::gather(., key = EstTopic, value = Word) %>%
    mutate(Word = gsub("\\s.*$", "", Word))
  if (show) {
    num_seededtopic <- length(model$seeds)
    # drop = FALSE keeps a matrix when only one topic is seeded; seq_len()
    # selects nothing (instead of column 1) when no topic is seeded.
    print(top_mat[, seq_len(num_seededtopic), drop = FALSE])
  }
  otidy
}
# Flatten a list of keyword vectors into a two-column tibble.
#
# Args:
#   lobj: list of character vectors, one element per seeded topic.
#
# Returns: tibble with columns
#   SeedTopic: "EstTopic<i>" for the i-th list element, repeated once per
#              keyword of that element;
#   Word:      all keywords, in list order.
list_to_tibble <- function(lobj) {
  obj_len <- lengths(lobj)
  element <- lobj %>% flatten_chr()
  # seq_along() keeps an empty list from producing labels for c(1, 0),
  # which made rep() fail on a length mismatch.
  tibble(
    SeedTopic = rep(paste0("EstTopic", seq_along(obj_len)), obj_len),
    Word = element
  )
}
# Count, for every keyword in `lobj`, how many rows of `otidy` (i.e. how many
# estimated topics' top-word lists) contain it; keywords never found get 0.
# Returns an ungrouped data frame with columns Word, SeedTopic, count.
.count_keyword_topics <- function(otidy, lobj) {
  otidy %>%
    right_join(list_to_tibble(lobj), by = "Word") %>%
    mutate(count = ifelse(is.na(EstTopic), 0, 1)) %>%
    group_by(Word, SeedTopic) %>%
    summarize(count = sum(count), .groups = "drop")
}

# Bar chart of the distribution of `count`, optionally faceted by SeedTopic.
.plot_topic_split <- function(organized, title, facet) {
  # geom_bar() uses stat = "count" by default; geom_histogram(stat = "count")
  # draws the same bars but warns about its unused binning parameters.
  g <- ggplot(organized, aes(x = factor(count))) +
    geom_bar()
  if (facet) {
    g <- g + facet_wrap(~ SeedTopic, ncol = 3)
  }
  g +
    xlab("Number of Topics") + ylab("Count") + ggtitle(title) +
    theme_bw(base_size = 15) +
    theme(plot.title = element_text(hjust = 0.5))
}

# Plot how many estimated topics each seed keyword appears in.
#
# Args:
#   otidy: long data frame of top words (columns EstTopic, Word), e.g. the
#          output of tidy_seededlda_out().
#   lobj:  list of keyword vectors, one element per seeded topic.
#
# Returns: list of three ggplot objects:
#   [[1]] distribution over all keywords of the number of estimated topics
#         whose top words contain the keyword;
#   [[2]] the same, faceted by seeded topic;
#   [[3]] the faceted plot restricted to the seeded estimated topics
#         (EstTopic1 ... EstTopic<length(lobj)>).
count_appearence_list <- function(otidy, lobj) {
  title <- "Words Split across Topics"
  organized <- .count_keyword_topics(otidy, lobj)

  # Only check topics with seed
  seed_topic_names <- paste0("EstTopic", seq_along(lobj))
  seeded_only <- filter(otidy, EstTopic %in% seed_topic_names)
  organized_seeded <- .count_keyword_topics(seeded_only, lobj)

  list(
    .plot_topic_split(organized, title, facet = FALSE),
    .plot_topic_split(organized, title, facet = TRUE),
    .plot_topic_split(organized_seeded,
                      paste0(title, "\n(Only topics with keywords)"),
                      facet = TRUE)
  )
}
library(tm)
# Fit a standard (unseeded) Gibbs LDA with tm/topicmodels and plot, for each
# seed keyword, how many LDA topics list it among their top words.
#
# Args:
#   doc_folder: directory whose files form the corpus (read via DirSource()).
#   seed_list:  list of keyword vectors, one per seeded topic.
#   iter_num:   number of Gibbs iterations.
#   k:          number of LDA topics.
#   topicvec:   unused; kept so existing calls keep working.
#   show_n:     top words per topic (ties on beta can add more rows).
#
# Returns: list of two ggplot objects: the overall distribution of counts and
#   the same distribution faceted by seeded topic.
get_lda_result <- function(doc_folder, seed_list, iter_num, k, topicvec = 1:k,
                           show_n = 15) {
  # Prepare data. Tokenising on whitespace only presumes the documents are
  # already space-separated tokens (TODO confirm for new corpora).
  corpus <- Corpus(DirSource(doc_folder))
  strsplit_space_tokenizer <- function(x) {
    unlist(strsplit(as.character(x), "[[:space:]]+"))
  }
  dtm <- DocumentTermMatrix(
    corpus,
    control = list(tokenize = strsplit_space_tokenizer,
                   stopwords = FALSE, tolower = TRUE,
                   stemming = FALSE, wordLengths = c(1, Inf))
  )
  lda <- LDA(dtm, k = k, control = list(seed = 225, iter = iter_num),
             method = "Gibbs")

  organized <- tidytext::tidy(lda) %>%
    group_by(topic) %>%
    # slice_max() keeps ties, like the superseded top_n() did
    slice_max(beta, n = show_n) %>%
    ungroup() %>%
    select(topics = topic, Word = term) %>%
    right_join(list_to_tibble(seed_list), by = "Word") %>%
    mutate(count = ifelse(is.na(topics), 0, 1)) %>%
    group_by(Word, SeedTopic) %>%
    summarize(count = sum(count), .groups = "drop")

  # geom_bar() == geom_histogram(stat = "count") without the unused-param warning
  g1 <- ggplot(organized, aes(x = factor(count))) +
    geom_bar() +
    xlab("Number of Topics") + ylab("Count") +
    ggtitle("Words Split across Topics") +
    theme_bw(base_size = 15) +
    theme(plot.title = element_text(hjust = 0.5))
  g2 <- g1 + facet_wrap(~ SeedTopic, ncol = 3)
  list(g1, g2)
}
data_folder <- tempfile()
# Simulation 1: 15 true topics, 6 seeded topics; the total number of topics
# (seeded + extra_k) is varied and compared with plain LDA of the same size.
# Printed console output is kept as `#>` comments so the script parses.
# NOTE(review): `iter_num` is used throughout but is not assigned anywhere in
# this script as shown; define it before running.
seed_list <- create_sim_data(saveDir=paste0(data_folder, "Sim1"), D=1000, K=15, TotalV=3000, alpha=0.1, beta_r=0.1, beta_s=0.1, p=c(rep(0.2, 5),rep(0.12, 5),rep(0.05, 5)), lambda=200, seeds_len=5)
#> [1] "Finished: "
#> user system elapsed
#> 9.428 0.943 10.735
seed_list <- lapply(seed_list, function(x){tolower(x)})
seed_list_full <- seed_list
seed_list <- seed_list_full[c(3,4,7,8,14,15)]
doc_folder <- paste0(data_folder, "Sim1", "/W")
# list.files() takes a regular expression, not a glob
docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
explore_ <- explore(docs,
  remove_numbers = FALSE, # For simulation, make it false
  remove_punct = TRUE,
  remove_symbols = TRUE,
  remove_separators = TRUE)
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=9)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=15)
#> [[1]]
#> [[2]]
model <- create_model(docs, seed_list, extra_k=19)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=25)
#> [[1]]
#> [[2]]
model <- create_model(docs, seed_list, extra_k=44)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=50)
#> [[1]]
#> [[2]]
model <- create_model(docs, seed_list, extra_k=94)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=100)
#> [[1]]
#> [[2]]
seed_list <- seed_list_full[c(3,4,5)]
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=22)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=25)
#> [[1]]
#> [[2]]
seed_list <- seed_list_full[c(11,14,15)]
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=22)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=25)
#> [[1]]
#> [[2]]
seed_list <- seed_list_full[c(6,5,7,9,12)]
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=20)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=25)
#> [[1]]
#> [[2]]
# Same topics (6, 5, 7, 9, 12), one keyword per topic taken from another topic
seed_list <- list(
  c(seed_list_full[[6]][1:4],seed_list_full[[13]][2]),
  c(seed_list_full[[5]][1:4],seed_list_full[[7]][2]),
  c(seed_list_full[[7]][1:4], seed_list_full[[3]][2]),
  c(seed_list_full[[9]][1:4], seed_list_full[[5]][2]),
  c(seed_list_full[[12]][1:4], seed_list_full[[8]][2])
)
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=20)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=25)
#> [[1]]
#> [[2]]
# Same topics, with 0, 1, 2, 2, 2 keywords taken from other topics
seed_list <- list(
  c(seed_list_full[[6]][1:5]),
  c(seed_list_full[[5]][1:4],seed_list_full[[7]][2]),
  # NOTE(review): this line read seed_list_full[[3]][1:3], which repeats
  # seed_list_full[[3]][2] inside the same topic; [[7]] follows the topic
  # order (6, 5, 7, 9, 12) of the lists above — confirm.
  c(seed_list_full[[7]][1:3], seed_list_full[[3]][2], seed_list_full[[5]][2]),
  c(seed_list_full[[9]][1:3], seed_list_full[[5]][3], seed_list_full[[7]][3]),
  c(seed_list_full[[12]][1:3], seed_list_full[[8]][2], seed_list_full[[6]][2])
)
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=20)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
# Simulation 2: 45 true topics, 6 seeded topics; the total number of topics is
# varied and compared with plain LDA of the same size.
# Printed console output is kept as `#>` comments so the script parses.
data_folder <- tempfile()
seed_list <- create_sim_data(saveDir=paste0(data_folder, "Sim1"), D=1000, K=45, TotalV=3000, alpha=0.1, beta_r=0.1, beta_s=0.1, p=c(rep(0.2, 15),rep(0.12, 15),rep(0.05, 15)), lambda=200, seeds_len=5)
#> [1] "Finished: "
#> user system elapsed
#> 7.968 1.137 9.199
seed_list <- lapply(seed_list, function(x){tolower(x)})
seed_list_full <- seed_list
seed_list <- seed_list_full[c(5,12,19,24,36,40)]
doc_folder <- paste0(data_folder, "Sim1", "/W")
# list.files() takes a regular expression, not a glob
docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
explore_ <- explore(docs,
  remove_numbers = FALSE, # For simulation, make it false
  remove_punct = TRUE,
  remove_symbols = TRUE,
  remove_separators = TRUE)
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=44)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=50)
#> [[1]]
#> [[2]]
model <- create_model(docs, seed_list, extra_k=74)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=80)
#> [[1]]
#> [[2]]
model <- create_model(docs, seed_list, extra_k=94)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=100)
#> [[1]]
#> [[2]]
# Topics 5, 12, 19, 24, 36, one keyword per topic taken from another topic
seed_list <- list(
  c(seed_list_full[[5]][1:4],seed_list_full[[45]][2]),
  c(seed_list_full[[12]][1:4],seed_list_full[[33]][2]),
  c(seed_list_full[[19]][1:4], seed_list_full[[16]][2]),
  c(seed_list_full[[24]][1:4], seed_list_full[[21]][2]),
  c(seed_list_full[[36]][1:4], seed_list_full[[9]][2])
)
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=55)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=60)
#> [[1]]
#> [[2]]
# Same topics, with 0, 1, 2, 2, 2 keywords taken from other topics
seed_list <- list(
  c(seed_list_full[[5]][1:5]),
  c(seed_list_full[[12]][1:4],seed_list_full[[9]][2]),
  c(seed_list_full[[19]][1:3], seed_list_full[[22]][2], seed_list_full[[40]][2]),
  c(seed_list_full[[24]][1:3], seed_list_full[[33]][3], seed_list_full[[28]][3]),
  c(seed_list_full[[36]][1:3], seed_list_full[[12]][2], seed_list_full[[16]][2])
)
explore_$visualize_dict_prop(seed_list)
model <- create_model(docs, seed_list, extra_k=55)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res), seed_list)
#> [[1]]
#> [[2]]
#> [[3]]
# Real data: Catalinac's documents (absolute path on the author's machine).
# Printed console output is kept as `#>` comments so the script parses.
doc_folder <- paste0("/Users/Shusei/Dropbox/Study/My_Research/TreeStructuredTopicModel/Papers/replication/Catalinac/data/docs") # Data from original data Document-Term Matrix
# list.files() takes a regular expression, not a glob
docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
explore_ <- explore(docs,
  remove_numbers = TRUE,
  remove_punct = TRUE,
  remove_symbols = TRUE,
  remove_separators = TRUE)
# Remove overlapping words for better comparison
seed_list <- list(c("農業 整備 漁業 開発 水産"), # agriculture, fishing industry
  c("税 消費 暮らし 景気 税金"), # tax
  c("介護 高齢 保険 長寿 健康"), # aging society
  c("軍事 戦争 軍 自衛隊 平和")) # military, war, peace
seed_list <- lapply(seed_list, function(x){strsplit(x, " ")[[1]]})
g <- explore_$visualize_dict_prop(seed_list)
# `family` is forwarded by ggsave() to the pdf() device (Japanese CID font)
ggsave("/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/Catalinac.pdf", g, width=5, height=4, family="Japan1GothicBBB")
# SeededLDA eight keywords
model <- create_model(docs, seed_list, extra_k=1)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=TRUE), seed_list)
#> 1 2 3 4
#> [1,] "社会" "党" "年金" "党"
#> [2,] "整備 [✓]" "日本" "制度" "政治"
#> [3,] "教育" "憲法" "改革" "日本"
#> [4,] "推進" "増税" "地域" "国民"
#> [5,] "図る" "政治" "安心" "税 [2]"
#> [6,] "作り" "消費 [✓]" "医療" "消費 [2]"
#> [7,] "実現" "守る" "円" "共産"
#> [8,] "地域" "税 [✓]" "実現" "自民党"
#> [9,] "充実" "国民" "地方" "企業"
#> [10,] "福祉" "社会" "日本" "守る"
#> [11,] "対策" "暮らし [✓]" "介護 [✓]" "反対"
#> [12,] "産業" "自民党" "社会" "選挙"
#> [13,] "振興" "円" "支援" "平和 [✓]"
#> [14,] "豊か" "保障" "ひと" "廃止"
#> [15,] "政治" "平和 [4]" "国" "民主"
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=6)
#> [[1]]
#> [[2]]
# SeededLDA eight keywords
model <- create_model(docs, seed_list, extra_k=16)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=TRUE), seed_list)
#> 1 2 3 4
#> [1,] "整備 [✓]" "消費 [✓]" "制度" "党"
#> [2,] "道路" "税 [✓]" "年金" "日本"
#> [3,] "地域" "憲法" "医療" "共産"
#> [4,] "産業" "党" "介護 [✓]" "国民"
#> [5,] "振興" "増税" "支援" "政治"
#> [6,] "推進" "日本" "実現" "税 [2]"
#> [7,] "道" "平和 [4]" "充実" "消費 [2]"
#> [8,] "促進" "暮らし [✓]" "雇用" "増税"
#> [9,] "建設" "守る" "負担" "反対"
#> [10,] "県" "政治" "安心" "企業"
#> [11,] "交通" "社会" "教育" "民主"
#> [12,] "実現" "アメリカ" "保険 [✓]" "守る"
#> [13,] "図る" "改悪" "対策" "選挙"
#> [14,] "早期" "保障" "社会" "つらぬく"
#> [15,] "都市" "企業" "企業" "基地"
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=20)
#> [[1]]
#> [[2]]
# SeededLDA eight keywords
model <- create_model(docs, seed_list, extra_k=46)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=TRUE), seed_list)
#> 1 2 3 4
#> [1,] "整備 [✓]" "党" "制度" "憲法"
#> [2,] "振興" "共産" "医療" "消費 [2]"
#> [3,] "道路" "日本" "年金" "税 [2]"
#> [4,] "道" "政治" "介護 [✓]" "平和 [✓]"
#> [5,] "農業 [✓]" "国民" "支援" "アメリカ"
#> [6,] "地域" "税 [✓]" "保険 [✓]" "改悪"
#> [7,] "産業" "消費 [✓]" "負担" "企業"
#> [8,] "図る" "増税" "雇用" "社会"
#> [9,] "農林" "民主" "充実" "保障"
#> [10,] "促進" "暮らし [✓]" "安心" "守る"
#> [11,] "推進" "自民党" "実現" "年金"
#> [12,] "県" "税金 [✓]" "保育" "増税"
#> [13,] "建設" "やめる" "拡充" "戦争 [✓]"
#> [14,] "交通" "選挙" "費" "反対"
#> [15,] "水産 [✓]" "目指す" "保障" "雇用"
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=50)
#> [[1]]
#> [[2]]
# Remove overlapping words for better comparison
# Second keyword set: five topics named after topic numbers (52, 62, 63, 20, 58).
# Printed console output is kept as `#>` comments so the script parses.
topic52 <- c("農業 産業 整備 漁業 開発")
topic62 <- c("復興 連立 被災 災害 ひと")
topic63 <- c("政治 主義 自由 社会 民主")
topic20 <- c("税 消費 廃止 国民 日本")
topic58 <- c("企業 教育 中小 充実 図る")
seed_list <- list(topic52, topic62, topic63, topic20, topic58)
seed_list <- lapply(seed_list, function(x){strsplit(x, " ")[[1]]})
g <- explore_$visualize_dict_prop(seed_list)
# `family` is forwarded by ggsave() to the pdf() device (Japanese CID font)
ggsave("/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/Catalinac2.pdf", g, width=5, height=4, family="Japan1GothicBBB")
model <- create_model(docs, seed_list, extra_k=1)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=TRUE), seed_list)
#> 1 2 3 4 5
#> [1,] "整備 [✓]" "円" "政治 [✓]" "党" "教育 [✓]"
#> [2,] "政治 [3]" "年金" "改革" "日本 [✓]" "税 [4]"
#> [3,] "社会 [3]" "制度" "国民 [4]" "国民 [✓]" "政治 [3]"
#> [4,] "推進" "医療" "日本 [4]" "共産" "福祉"
#> [5,] "地域" "政権" "選挙" "政治 [3]" "守る"
#> [6,] "図る [5]" "兆" "自民党" "増税" "平和"
#> [7,] "作り" "郵政" "党" "税 [✓]" "実現"
#> [8,] "豊か" "無駄" "新しい" "消費 [✓]" "消費 [4]"
#> [9,] "振興" "民営" "政権" "憲法" "充実 [✓]"
#> [10,] "実現" "実現" "社会 [✓]" "守る" "円"
#> [11,] "産業 [✓]" "廃止 [4]" "実現" "自民党" "社会 [3]"
#> [12,] "福祉" "金" "腐敗" "企業 [5]" "中小 [✓]"
#> [13,] "発展" "ひと [✓]" "民主 [✓]" "反対" "減税"
#> [14,] "対策" "税金" "ひと [2]" "暮らし" "企業 [✓]"
#> [15,] "国際" "財源" "世界" "民主 [3]" "年金"
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=6)
#> [[1]]
#> [[2]]
model <- create_model(docs, seed_list, extra_k=10)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=TRUE), seed_list)
#> 1 2 3 4 5
#> [1,] "整備 [✓]" "年金" "政治 [✓]" "党" "教育 [✓]"
#> [2,] "産業 [✓]" "円" "自民党" "日本 [✓]" "福祉"
#> [3,] "推進" "政権" "国民 [4]" "消費 [✓]" "図る [✓]"
#> [4,] "地域" "制度" "改革" "共産" "社会 [3]"
#> [5,] "振興" "無駄" "党" "国民 [✓]" "充実 [✓]"
#> [6,] "作り" "地域" "選挙" "税 [✓]" "企業 [✓]"
#> [7,] "道路" "交代" "金" "政治 [3]" "守る"
#> [8,] "図る [5]" "廃止 [4]" "腐敗" "増税" "農業 [1]"
#> [9,] "豊か" "ひと [✓]" "主義 [✓]" "企業 [5]" "制度"
#> [10,] "社会 [3]" "医療" "税 [4]" "反対" "中小 [✓]"
#> [11,] "県" "金" "消費 [4]" "民主 [3]" "進める"
#> [12,] "実現" "実現" "権" "守る" "年金"
#> [13,] "道" "税金" "廃止 [4]" "選挙" "生活"
#> [14,] "建設" "兆" "企業 [5]" "つらぬく" "実現"
#> [15,] "促進" "民主 [3]" "日本 [4]" "基地" "医療"
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=15)
#> [[1]]
#> [[2]]
model <- create_model(docs, seed_list, extra_k=25)
res <- topicdict_train(model, iter = iter_num)
count_appearence_list(tidy_seededlda_out(model, res, show=TRUE), seed_list)
#> 1 2 3 4 5
#> [1,] "整備 [✓]" "増税" "政治 [✓]" "日本 [✓]" "企業 [✓]"
#> [2,] "道路" "党" "税 [4]" "党" "教育 [✓]"
#> [3,] "地域" "憲法" "消費 [4]" "国民 [✓]" "図る [✓]"
#> [4,] "振興" "守る" "社会 [✓]" "共産" "充実 [✓]"
#> [5,] "道" "円" "自由 [✓]" "税 [✓]" "中小 [✓]"
#> [6,] "産業 [✓]" "自民党" "国民 [4]" "政治 [3]" "対策"
#> [7,] "県" "反対" "自民党" "消費 [✓]" "福祉"
#> [8,] "漁業 [✓]" "民営" "廃止 [4]" "増税" "制度"
#> [9,] "促進" "税 [4]" "金" "民主 [3]" "確立"
#> [10,] "建設" "民主 [3]" "選挙" "守る" "安定"
#> [11,] "早期" "庶民" "実現" "反対" "農業 [1]"
#> [12,] "交通" "改悪" "民主 [✓]" "選挙" "実現"
#> [13,] "農業 [✓]" "郵政" "権" "自民党" "進める"
#> [14,] "推進" "日本 [4]" "党" "主人公" "つとめる"
#> [15,] "農林" "野党" "院" "企業 [5]" "社会 [3]"
#> [[1]]
#> [[2]]
#> [[3]]
get_lda_result(doc_folder, seed_list, iter_num, k=30)
#> [[1]]
#> [[2]]
# Heatmap of how the top words of each estimated topic are distributed over
# the true (simulated) topics that generated them.
#
# Args:
#   post:      posterior object from topicdict::posterior().
#   n:         number of top words per estimated topic.
#   title_:    if TRUE, add a centred title.
#   seed_list: if non-NULL, keep only estimated topics 1..length(seed_list)
#              (the seeded ones).
#   topicvec:  order of estimated topics on the axis. When NULL it is derived
#              from the data; when its length differs from the number of
#              estimated topics shown, a message is emitted and 1..num is used.
#   merge:     list of length-3 vectors c(a, b, c); true topics a and b are
#              relabelled as true topic c.
#
# Returns: a ggplot object.
#
# The true topic of a word is taken from the text after a literal "t" in the
# word (tidyr::separate(sep = "t")); presumably words look like
# "<id>t<topic>" — TODO confirm against create_sim_data().
diagnosis_topic_recovery_heatmap <- function(post, n = 25, title_ = TRUE,
                                             seed_list = NULL,
                                             topicvec = c(), merge = list()) {
  topwords <- top_terms(post, n = n)
  topwords <- data.frame(topwords)
  colnames(topwords) <- paste0("EstTopic", seq_len(ncol(topwords)))
  topwords <- tidyr::gather(topwords, key = EstTopic, value = Word) %>%
    mutate(Word = gsub("\\s.*$", "", Word))
  res_ <- topwords %>%
    mutate(RawWord = Word) %>%
    tidyr::separate(Word, into = c("word_id", "TrueTopic"), sep = "t") %>%
    mutate(TrueTopic = paste0("True", as.character(TrueTopic)))

  # Merge topics: relabel true topics m[1] and m[2] as m[3]
  for (m in merge) {
    mt <- paste0("True", m)
    res_ <- res_ %>%
      mutate(TrueTopic = replace(TrueTopic, TrueTopic == mt[1], mt[3])) %>%
      mutate(TrueTopic = replace(TrueTopic, TrueTopic == mt[2], mt[3]))
  }

  # Share (in %) of each estimated topic's top words coming from each true topic
  res_ <- res_ %>%
    group_by(EstTopic, TrueTopic) %>%
    summarise(counts = n(), .groups = "drop") %>%
    group_by(EstTopic) %>%
    mutate(topicsum = sum(counts)) %>%
    ungroup() %>%
    mutate(Proportion = counts / topicsum * 100)

  if (!is.null(seed_list)) {
    # Use only topics with keywords
    seed_list_name <- paste0("EstTopic", seq_along(seed_list))
    res_ <- res_ %>% filter(EstTopic %in% seed_list_name)
  }

  num <- length(unique(res_$EstTopic))
  if (is.null(topicvec)) {
    res_ %>%
      group_by(EstTopic) %>%
      top_n(1, Proportion) %>%
      mutate(forranking = as.integer(gsub("EstTopic", "", EstTopic))) %>%
      arrange(forranking) %>%
      select(EstTopic) -> topicvec
    topicvec <- unique(as.integer(gsub("EstTopic", "", topicvec$EstTopic)))
  } else if (length(topicvec) != num) {
    message("topicvec length does not match")
    topicvec <- seq_len(num)
  }

  # Axis limits come from the true-topic labels actually present, ordered by
  # topic number. Building them as True1..True<count of labels> hid real
  # topics and showed empty ones whenever the labels were not contiguous
  # (after `merge`, or when `seed_list` keeps only some estimated topics).
  true_levels <- unique(res_$TrueTopic)
  true_index <- suppressWarnings(as.integer(sub("^True", "", true_levels)))
  true_levels <- true_levels[order(true_index, true_levels)]

  title <- paste0("Seeded LDA: Top ", as.character(n), " words")
  g <- ggplot(res_, aes(EstTopic, TrueTopic)) +
    geom_tile(aes(fill = Proportion)) +
    scale_fill_gradient(limits = c(0, 100), low = "#e8e8e8", high = "#0072B2",
                        name = "Proportion") +
    scale_x_discrete(limits = rev(paste0("EstTopic", topicvec))) +
    coord_flip() +
    scale_y_discrete(limits = true_levels) +
    xlab("Estimated Topics") + ylab("True Topic") + theme_bw(base_size = 13)
  if (title_) {
    g <- g + ggtitle(title) +
      theme(plot.title = element_text(hjust = 0.5))
  }
  g
}
library(grid)
library(gridExtra)
# Simulate data, fit a seeded topicdict model, and save the recovery heatmap
# and posterior for every combination of true and estimated topic counts.
#
# Args:
#   trueK:              numbers of true topics to simulate.
#   estimatedK:         total numbers of topics to estimate.
#   seeds_len:          keywords per seeded topic.
#   seed_only:          if TRUE, the heatmap shows only the seeded topics.
#   seed_contamination: number of contamination rounds; each round replaces one
#                       random keyword of every topic with a random keyword of
#                       a randomly chosen topic (possibly the same topic).
#   out_dir:            directory for the saved .obj files (created if missing).
#   iter:               iterations for topicdict_train(); defaults to the
#                       global `iter_num` used elsewhere in this script.
#
# Side effects: writes fig_T<trueK>_E<estimatedK>.obj and
#   post_T<trueK>_E<estimatedK>.obj with saveRDS(); emits progress messages.
# Returns: NULL, invisibly.
run_simulations <- function(trueK, estimatedK, seeds_len = 6,
                            seed_only = FALSE, seed_contamination = 0,
                            out_dir = "/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/obj",
                            iter = iter_num) {
  # Create Combinations
  combinations <- expand.grid(trueK, estimatedK)
  num_combinations <- nrow(combinations)
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

  for (s in seq_len(num_combinations)) {
    trueK_ <- combinations[s, 1]
    estimatedK_ <- combinations[s, 2]

    # Create data (fixed seed so every combination is reproducible).
    # `p` stays an inline argument so rnorm() runs at the same point as before.
    set.seed(225)
    sim_dir <- paste0(tempfile(), "Sim1")
    seed_list <- create_sim_data(saveDir = sim_dir,
                                 D = 1000, K = trueK_, TotalV = 3000, alpha = 0.1,
                                 beta_r = 0.1, beta_s = 0.1,
                                 p = rep(0.15, trueK_) + rnorm(trueK_, mean = 0, sd = 0.04),
                                 lambda = 200, seeds_len = seeds_len)
    seed_list <- lapply(seed_list, tolower)

    # Seed contamination
    for (i in seq_len(seed_contamination)) {
      seed_list <- lapply(seed_list, function(x) {
        x[sample(1:seeds_len, 1)] <- seed_list[[sample(1:trueK_, 1)]][sample(1:seeds_len, 1)]
        x
      })
    }

    # Fit the model
    extra_k_ <- estimatedK_ - length(seed_list)
    if (extra_k_ < 0) {
      message("extra_k is negative, setting it to 0")
      extra_k_ <- 0
      seed_list <- seed_list[1:estimatedK_]
    }
    doc_folder <- paste0(sim_dir, "/W")
    # list.files() takes a regular expression, not a glob
    docs <- list.files(doc_folder, pattern = "\\.txt$", full.names = TRUE)
    model <- create_model(docs, seed_list, extra_k = extra_k_)
    res <- topicdict_train(model, iter = iter)
    post <- topicdict::posterior(res)
    heatmap_seeds <- if (seed_only) seed_list else NULL
    g <- diagnosis_topic_recovery_heatmap(post, 15, title_ = FALSE,
                                          seed_list = heatmap_seeds)

    # Save
    saveRDS(g, file = file.path(out_dir,
                                paste0("fig_T", trueK_, "_E", estimatedK_, ".obj")))
    saveRDS(post, file = file.path(out_dir,
                                   paste0("post_T", trueK_, "_E", estimatedK_, ".obj")))
    # The simulated corpus is only needed for this fit; delete it so a full
    # grid run does not leave thousands of files in tempdir().
    unlink(sim_dir, recursive = TRUE)
    message("Done: ", s, "/", num_combinations)
  }
  invisible(NULL)
}
# Arrange the saved recovery heatmaps into one grid figure with a shared legend:
# one row per estimated K (largest first), one column per true K.
#
# Args:
#   trueK, estimatedK: the values that were passed to run_simulations().
#   title_:            figure title.
#   obj_dir:           directory holding the fig_T<trueK>_E<estimatedK>.obj files.
#
# Returns: the arranged gtable, invisibly; it is also drawn with grid.draw().
create_simulation_figure <- function(trueK, estimatedK,
                                     title_ = "Simulation Results",
                                     obj_dir = "/Users/Shusei/Dropbox/Study/Project/ImaiText/topicdict/vignettes/obj") {
  # Largest estimated K first. arrange(rev(Var2)) only produced this when
  # estimatedK happened to be passed in ascending order.
  combinations <- expand.grid(trueK, estimatedK) %>%
    arrange(desc(Var2))

  # Load Data
  figures <- lapply(seq_len(nrow(combinations)), function(s) {
    readRDS(file = file.path(obj_dir, paste0("fig_T", combinations[s, 1],
                                             "_E", combinations[s, 2], ".obj")))
  })

  # Shared legend from the first figure. Newer ggplot2 releases add several
  # "guide-box-*" layout entries (unused ones are zeroGrob), so take the first
  # non-empty one; indexing [[ ]] with all matches would recurse and fail.
  g1 <- ggplotGrob(figures[[1]])
  guide_ids <- grep("guide", g1$layout$name)
  is_empty <- vapply(g1$grobs[guide_ids], inherits, logical(1), what = "zeroGrob")
  legend <- g1$grobs[[guide_ids[!is_empty][1]]]

  # Strip per-panel legends and axes; shared labels are added around the grid
  hide_details <- theme(legend.position = "none",
                        axis.title.x = element_blank(),
                        axis.title.y = element_blank(),
                        axis.text.x = element_blank(),
                        axis.text.y = element_blank())
  figures <- lapply(figures, function(x) x + hide_details)

  # Layout cf. https://stackoverflow.com/a/11093069/4357279
  g <- arrangeGrob(grobs = figures,
                   nrow = length(estimatedK),
                   right = legend,
                   top = textGrob(title_),
                   left = textGrob("Estimated Topic", rot = 90, vjust = 1),
                   bottom = textGrob("True Topic", vjust = -0.1))
  grid.draw(g) # Show plot
  invisible(g)
}
# Run the simulations for every (trueK, estimatedK) pair, then draw the
# combined heatmap figure.
#
# Args:
#   trueK, estimatedK, seed_only, seed_contamination, seeds_len:
#     passed to run_simulations() (defaults match run_simulations()).
#   title_: figure title, passed to create_simulation_figure().
multiple_simulations <- function(trueK,
                                 estimatedK,
                                 seed_only = FALSE,
                                 seed_contamination = 0,
                                 seeds_len = 6,
                                 title_ = "Simulation Results") {
  # Run Simulation
  run_simulations(trueK, estimatedK, seeds_len = seeds_len,
                  seed_only = seed_only,
                  seed_contamination = seed_contamination)
  # Create Figure
  create_simulation_figure(trueK, estimatedK, title_ = title_)
}
# How many "true" topics can keywords collect?
# Printed console output is kept as `#>` comments so the script parses.
multiple_simulations(trueK=c(5,15,25,35), estimatedK=c(5,15,25,35))
#> [1] "Finished: "
#> user system elapsed
#> 10.108 1.706 11.849
#> [1] "Finished: "
#> user system elapsed
#> 8.461 0.996 9.494
#> [1] "Finished: "
#> user system elapsed
#> 7.844 1.054 8.928
#> [1] "Finished: "
#> user system elapsed
#> 7.460 0.914 8.393
#> [1] "Finished: "
#> user system elapsed
#> 10.134 1.666 11.820
#> [1] "Finished: "
#> user system elapsed
#> 7.605 1.090 8.710
#> [1] "Finished: "
#> user system elapsed
#> 7.720 1.006 8.743
#> [1] "Finished: "
#> user system elapsed
#> 7.849 1.024 8.891
#> [1] "Finished: "
#> user system elapsed
#> 10.019 1.698 11.738
#> [1] "Finished: "
#> user system elapsed
#> 8.187 1.069 9.278
#> [1] "Finished: "
#> user system elapsed
#> 7.873 1.053 8.953
#> [1] "Finished: "
#> user system elapsed
#> 7.443 0.848 8.312
#> [1] "Finished: "
#> user system elapsed
#> 10.060 1.699 11.780
#> [1] "Finished: "
#> user system elapsed
#> 8.053 1.181 9.252
#> [1] "Finished: "
#> user system elapsed
#> 7.707 0.921 8.650
#> [1] "Finished: "
#> user system elapsed
#> 7.707 1.026 8.752
# How many "true" topics can keywords collect, when each keyword list goes
# through two contamination rounds?
multiple_simulations(trueK=c(5,15,25,35), estimatedK=c(5,15,25,35),
                     seed_contamination=2)
#> [1] "Finished: "
#> user system elapsed
#> 10.447 1.823 12.291
#> [1] "Finished: "
#> user system elapsed
#> 8.428 1.089 9.535
#> [1] "Finished: "
#> user system elapsed
#> 7.633 0.932 8.587
#> [1] "Finished: "
#> user system elapsed
#> 7.663 0.951 8.630
#> [1] "Finished: "
#> user system elapsed
#> 10.217 1.781 12.019
#> [1] "Finished: "
#> user system elapsed
#> 8.067 1.106 9.195
#> [1] "Finished: "
#> user system elapsed
#> 7.929 1.091 9.037
#> [1] "Finished: "
#> user system elapsed
#> 7.391 1.046 8.458
#> [1] "Finished: "
#> user system elapsed
#> 10.192 1.758 11.970
#> [1] "Finished: "
#> user system elapsed
#> 8.415 1.135 9.568
#> [1] "Finished: "
#> user system elapsed
#> 7.730 1.080 8.832
#> [1] "Finished: "
#> user system elapsed
#> 7.512 1.067 8.597
#> [1] "Finished: "
#> user system elapsed
#> 10.083 1.825 11.933
#> [1] "Finished: "
#> user system elapsed
#> 7.948 1.191 9.157
#> [1] "Finished: "
#> user system elapsed
#> 7.586 1.004 8.607
#> [1] "Finished: "
#> user system elapsed
#> 7.423 1.019 8.459